{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "HyperNeRF Training.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EZ_wkNVdTz-C"
      },
      "source": [
        "# Let's train HyperNeRF!\n",
        "\n",
        "**Author**: [Keunhong Park](https://keunhong.com)\n",
        "\n",
        "[[Project Page](https://hypernerf.github.io)]\n",
        "[[Paper](https://arxiv.org/abs/2106.13228)]\n",
        "[[GitHub](https://github.com/google/hypernerf)]\n",
        "\n",
        "This notebook provides a demo for training HyperNeRF.\n",
        "\n",
        "### Instructions\n",
        "\n",
        "1. Convert a video into our dataset format using the Nerfies [dataset processing notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb).\n",
        "2. Set the `data_dir` below to where you saved the dataset.\n",
        "3. Come back to this notebook to train HyperNeRF.\n",
        "\n",
        "\n",
        "### Notes\n",
        " * To accommodate the limited compute power of Colab runtimes, this notebook defaults to a \"toy\" version of our method. The number of samples has been reduced and elastic regularization is turned off.\n",
        "\n",
        " * To train a high-quality model, please look at the CLI options we provide in the [GitHub repository](https://github.com/google/hypernerf).\n",
        "\n",
        "\n",
        "\n",
        " * Please report issues on the [GitHub issue tracker](https://github.com/google/hypernerf/issues).\n",
        "\n",
        "\n",
        "If you find this work useful, please consider citing:\n",
        "```bibtex\n",
        "@article{park2021hypernerf,\n",
        "  author    = {Park, Keunhong and Sinha, Utkarsh and Hedman, Peter and Barron, Jonathan T. and Bouaziz, Sofien and Goldman, Dan B and Martin-Brualla, Ricardo and Seitz, Steven M.},\n",
        "  title     = {HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields},\n",
        "  journal   = {arXiv preprint arXiv:2106.13228},\n",
        "  year      = {2021},\n",
        "}\n",
        "```\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OlW1gF_djH6H"
      },
      "source": [
        "## Environment Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I6Jbspl7TnIX"
      },
      "source": [
        "!pip install flax immutabledict mediapy\n",
        "!pip install --upgrade git+https://github.com/google/hypernerf"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zGJux-m5Xp3Z",
        "cellView": "form"
      },
      "source": [
        "# @title Configure notebook runtime\n",
        "# @markdown If you would like to use a GPU runtime instead, change the runtime type by going to `Runtime > Change runtime type`. \n",
        "# @markdown You will have to use a smaller batch size on GPU.\n",
        "\n",
        "import jax\n",
        "\n",
        "runtime_type = 'tpu'  # @param ['gpu', 'tpu']\n",
        "if runtime_type == 'tpu':\n",
        "  import jax.tools.colab_tpu\n",
        "  jax.tools.colab_tpu.setup_tpu()\n",
        "\n",
        "print('Detected Devices:', jax.devices())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "afUtLfRWULEi",
        "cellView": "form"
      },
      "source": [
        "# @title Mount Google Drive\n",
        "# @markdown Mount Google Drive onto `/content/gdrive`. You can skip this if running locally.\n",
        "\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ENOfbG3AkcVN",
        "cellView": "form"
      },
      "source": [
        "# @title Define imports and utility functions.\n",
        "\n",
        "import jax\n",
        "from jax.config import config as jax_config\n",
        "import jax.numpy as jnp\n",
        "from jax import grad, jit, vmap\n",
        "from jax import random\n",
        "\n",
        "import flax\n",
        "import flax.linen as nn\n",
        "from flax import jax_utils\n",
        "from flax import optim\n",
        "from flax.metrics import tensorboard\n",
        "from flax.training import checkpoints\n",
        "jax_config.enable_omnistaging() # Linen requires enabling omnistaging\n",
        "\n",
        "from absl import logging\n",
        "from io import BytesIO\n",
        "import random as pyrandom\n",
        "import numpy as np\n",
        "import PIL\n",
        "import IPython\n",
        "\n",
        "from hypernerf import image_utils\n",
        "\n",
        "\n",
        "# Monkey patch logging.\n",
        "def myprint(msg, *args, **kwargs):\n",
        "  print(msg % args)\n",
        "\n",
        "logging.info = myprint\n",
        "logging.warn = myprint\n",
        "logging.error = myprint\n",
        "\n",
        "\n",
        "def show_image(image, fmt='png'):\n",
        "  image = image_utils.image_to_uint8(image)\n",
        "  f = BytesIO()\n",
        "  PIL.Image.fromarray(image).save(f, fmt)\n",
        "  IPython.display.display(IPython.display.Image(data=f.getvalue()))\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wW7FsSB-jORB"
      },
      "source": [
        "## Configuration"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rz7wRm7YT9Ka"
      },
      "source": [
        "# @title Model and dataset configuration\n",
        "\n",
        "from pathlib import Path\n",
        "from pprint import pprint\n",
        "import gin\n",
        "from IPython.display import display, Markdown\n",
        "\n",
        "from hypernerf import models\n",
        "from hypernerf import modules\n",
        "from hypernerf import warping\n",
        "from hypernerf import datasets\n",
        "from hypernerf import configs\n",
        "\n",
        "\n",
        "# @markdown The working directory.\n",
        "train_dir = '/content/gdrive/My Drive/nerfies/hypernerf_experiments/capture1/exp1'  # @param {type: \"string\"}\n",
        "# @markdown The directory to the dataset capture.\n",
        "data_dir = '/content/gdrive/My Drive/nerfies/captures/capture1'  # @param {type: \"string\"}\n",
        "\n",
        "# @markdown Training configuration.\n",
        "max_steps = 100000  # @param {type: 'number'}\n",
        "batch_size = 4096  # @param {type: 'number'}\n",
        "image_scale = 8  # @param {type: 'number'}\n",
        "\n",
        "# @markdown Model configuration.\n",
        "use_viewdirs = True  #@param {type: 'boolean'}\n",
        "use_appearance_metadata = True  #@param {type: 'boolean'}\n",
        "num_coarse_samples = 64  # @param {type: 'number'}\n",
        "num_fine_samples = 64  # @param {type: 'number'}\n",
        "\n",
        "# @markdown Deformation configuration.\n",
        "use_warp = True  #@param {type: 'boolean'}\n",
        "warp_field_type = '@SE3Field'  #@param['@SE3Field', '@TranslationField']\n",
        "warp_min_deg = 0  #@param{type:'number'}\n",
        "warp_max_deg = 6  #@param{type:'number'}\n",
        "\n",
        "# @markdown Hyper-space configuration.\n",
        "hyper_num_dims = 8  #@param{type:'number'}\n",
        "hyper_point_min_deg = 0  #@param{type:'number'}\n",
        "hyper_point_max_deg = 1  #@param{type:'number'}\n",
        "hyper_slice_method = 'bendy_sheet'  #@param['none', 'axis_aligned_plane', 'bendy_sheet']\n",
        "\n",
        "\n",
        "checkpoint_dir = Path(train_dir, 'checkpoints')\n",
        "checkpoint_dir.mkdir(exist_ok=True, parents=True)\n",
        "\n",
        "config_str = f\"\"\"\n",
        "DELAYED_HYPER_ALPHA_SCHED = {{\n",
        "  'type': 'piecewise',\n",
        "  'schedules': [\n",
        "    (1000, ('constant', 0.0)),\n",
        "    (0, ('linear', 0.0, %hyper_point_max_deg, 10000))\n",
        "  ],\n",
        "}}\n",
        "\n",
        "ExperimentConfig.image_scale = {image_scale}\n",
        "ExperimentConfig.datasource_cls = @NerfiesDataSource\n",
        "NerfiesDataSource.data_dir = '{data_dir}'\n",
        "NerfiesDataSource.image_scale = {image_scale}\n",
        "\n",
        "NerfModel.use_viewdirs = {int(use_viewdirs)}\n",
        "NerfModel.use_rgb_condition = {int(use_appearance_metadata)}\n",
        "NerfModel.num_coarse_samples = {num_coarse_samples}\n",
        "NerfModel.num_fine_samples = {num_fine_samples}\n",
        "\n",
        "NerfModel.use_stratified_sampling = True\n",
        "NerfModel.use_posenc_identity = False\n",
        "NerfModel.nerf_trunk_width = 128\n",
        "NerfModel.nerf_trunk_depth = 8\n",
        "\n",
        "TrainConfig.max_steps = {max_steps}\n",
        "TrainConfig.batch_size = {batch_size}\n",
        "TrainConfig.print_every = 100\n",
        "TrainConfig.use_elastic_loss = False\n",
        "TrainConfig.use_background_loss = False\n",
        "\n",
        "# Warp configs.\n",
        "warp_min_deg = {warp_min_deg}\n",
        "warp_max_deg = {warp_max_deg}\n",
        "NerfModel.use_warp = {use_warp}\n",
        "SE3Field.min_deg = %warp_min_deg\n",
        "SE3Field.max_deg = %warp_max_deg\n",
        "SE3Field.use_posenc_identity = False\n",
        "NerfModel.warp_field_cls = {warp_field_type}\n",
        "\n",
        "TrainConfig.warp_alpha_schedule = {{\n",
        "    'type': 'linear',\n",
        "    'initial_value': {warp_min_deg},\n",
        "    'final_value': {warp_max_deg},\n",
        "    'num_steps': {int(max_steps*0.8)},\n",
        "}}\n",
        "\n",
        "# Hyper configs.\n",
        "hyper_num_dims = {hyper_num_dims}\n",
        "hyper_point_min_deg = {hyper_point_min_deg}\n",
        "hyper_point_max_deg = {hyper_point_max_deg}\n",
        "\n",
        "NerfModel.hyper_embed_cls = @hyper/GLOEmbed\n",
        "hyper/GLOEmbed.num_dims = %hyper_num_dims\n",
        "NerfModel.hyper_point_min_deg = %hyper_point_min_deg\n",
        "NerfModel.hyper_point_max_deg = %hyper_point_max_deg\n",
        "\n",
        "TrainConfig.hyper_alpha_schedule = %DELAYED_HYPER_ALPHA_SCHED\n",
        "\n",
        "hyper_sheet_min_deg = 0\n",
        "hyper_sheet_max_deg = 6\n",
        "HyperSheetMLP.min_deg = %hyper_sheet_min_deg\n",
        "HyperSheetMLP.max_deg = %hyper_sheet_max_deg\n",
        "HyperSheetMLP.output_channels = %hyper_num_dims\n",
        "\n",
        "NerfModel.hyper_slice_method = '{hyper_slice_method}'\n",
        "NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP\n",
        "NerfModel.hyper_use_warp_embed = True\n",
        "\n",
        "TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg)\n",
        "\"\"\"\n",
        "\n",
        "gin.parse_config(config_str)\n",
        "\n",
        "config_path = Path(train_dir, 'config.gin')\n",
        "with open(config_path, 'w') as f:\n",
        "  logging.info('Saving config to %s', config_path)\n",
        "  f.write(config_str)\n",
        "\n",
        "exp_config = configs.ExperimentConfig()\n",
        "train_config = configs.TrainConfig()\n",
        "eval_config = configs.EvalConfig()\n",
        "\n",
        "display(Markdown(\n",
        "    gin.config.markdown(gin.config_str())))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r872r6hiVUVS",
        "cellView": "form"
      },
      "source": [
        "# @title Create datasource and show an example.\n",
        "\n",
        "from hypernerf import datasets\n",
        "from hypernerf import image_utils\n",
        "\n",
        "dummy_model = models.NerfModel({}, 0, 0)\n",
        "datasource = exp_config.datasource_cls(\n",
        "    image_scale=exp_config.image_scale,\n",
        "    random_seed=exp_config.random_seed,\n",
        "    # Enable metadata based on model needs.\n",
        "    use_warp_id=dummy_model.use_warp,\n",
        "    use_appearance_id=(\n",
        "        dummy_model.nerf_embed_key == 'appearance'\n",
        "        or dummy_model.hyper_embed_key == 'appearance'),\n",
        "    use_camera_id=dummy_model.nerf_embed_key == 'camera',\n",
        "    use_time=dummy_model.warp_embed_key == 'time')\n",
        "\n",
        "show_image(datasource.load_rgb(datasource.train_ids[0]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XC3PIY74XB05",
        "cellView": "form"
      },
      "source": [
        "# @title Create training iterators\n",
        "\n",
        "devices = jax.local_devices()\n",
        "\n",
        "train_iter = datasource.create_iterator(\n",
        "    datasource.train_ids,\n",
        "    flatten=True,\n",
        "    shuffle=True,\n",
        "    batch_size=train_config.batch_size,\n",
        "    prefetch_size=3,\n",
        "    shuffle_buffer_size=train_config.shuffle_buffer_size,\n",
        "    devices=devices,\n",
        ")\n",
        "\n",
        "def shuffled(l):\n",
        "  import random as r\n",
        "  import copy\n",
        "  l = copy.copy(l)\n",
        "  r.shuffle(l)\n",
        "  return l\n",
        "\n",
        "train_eval_iter = datasource.create_iterator(\n",
        "    shuffled(datasource.train_ids), batch_size=0, devices=devices)\n",
        "val_eval_iter = datasource.create_iterator(\n",
        "    shuffled(datasource.val_ids), batch_size=0, devices=devices)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "erY9l66KjYYW"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nZnS8BhcXe5E",
        "cellView": "form"
      },
      "source": [
        "# @title Initialize model\n",
        "# @markdown Defines the model and initializes its parameters.\n",
        "\n",
        "from flax.training import checkpoints\n",
        "from hypernerf import models\n",
        "from hypernerf import model_utils\n",
        "from hypernerf import schedules\n",
        "from hypernerf import training\n",
        "\n",
        "# @markdown Restore a checkpoint if one exists.\n",
        "restore_checkpoint = False  # @param{type:'boolean'}\n",
        "\n",
        "\n",
        "rng = random.PRNGKey(exp_config.random_seed)\n",
        "np.random.seed(exp_config.random_seed + jax.process_index())\n",
        "devices_to_use = jax.devices()\n",
        "\n",
        "learning_rate_sched = schedules.from_config(train_config.lr_schedule)\n",
        "nerf_alpha_sched = schedules.from_config(train_config.nerf_alpha_schedule)\n",
        "warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)\n",
        "elastic_loss_weight_sched = schedules.from_config(\n",
        "    train_config.elastic_loss_weight_schedule)\n",
        "hyper_alpha_sched = schedules.from_config(train_config.hyper_alpha_schedule)\n",
        "hyper_sheet_alpha_sched = schedules.from_config(\n",
        "    train_config.hyper_sheet_alpha_schedule)\n",
        "\n",
        "rng, key = random.split(rng)\n",
        "params = {}\n",
        "model, params['model'] = models.construct_nerf(\n",
        "      key,\n",
        "      batch_size=train_config.batch_size,\n",
        "      embeddings_dict=datasource.embeddings_dict,\n",
        "      near=datasource.near,\n",
        "      far=datasource.far)\n",
        "\n",
        "optimizer_def = optim.Adam(learning_rate_sched(0))\n",
        "optimizer = optimizer_def.create(params)\n",
        "\n",
        "state = model_utils.TrainState(\n",
        "    optimizer=optimizer,\n",
        "    nerf_alpha=nerf_alpha_sched(0),\n",
        "    warp_alpha=warp_alpha_sched(0),\n",
        "    hyper_alpha=hyper_alpha_sched(0),\n",
        "    hyper_sheet_alpha=hyper_sheet_alpha_sched(0))\n",
        "scalar_params = training.ScalarParams(\n",
        "    learning_rate=learning_rate_sched(0),\n",
        "    elastic_loss_weight=elastic_loss_weight_sched(0),\n",
        "    warp_reg_loss_weight=train_config.warp_reg_loss_weight,\n",
        "    warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,\n",
        "    warp_reg_loss_scale=train_config.warp_reg_loss_scale,\n",
        "    background_loss_weight=train_config.background_loss_weight,\n",
        "    hyper_reg_loss_weight=train_config.hyper_reg_loss_weight)\n",
        "\n",
        "if restore_checkpoint:\n",
        "  logging.info('Restoring checkpoint from %s', checkpoint_dir)\n",
        "  state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n",
        "step = state.optimizer.state.step + 1\n",
        "state = jax_utils.replicate(state, devices=devices)\n",
        "del params"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "at2CL5DRZ7By",
        "cellView": "form"
      },
      "source": [
        "# @title Define pmapped functions\n",
        "# @markdown This parallelizes the training and evaluation step functions using `jax.pmap`.\n",
        "\n",
        "import functools\n",
        "from hypernerf import evaluation\n",
        "\n",
        "\n",
        "def _model_fn(key_0, key_1, params, rays_dict, extra_params):\n",
        "  out = model.apply({'params': params},\n",
        "                    rays_dict,\n",
        "                    extra_params=extra_params,\n",
        "                    rngs={\n",
        "                        'coarse': key_0,\n",
        "                        'fine': key_1\n",
        "                    },\n",
        "                    mutable=False)\n",
        "  return jax.lax.all_gather(out, axis_name='batch')\n",
        "\n",
        "pmodel_fn = jax.pmap(\n",
        "    # Note rng_keys are useless in eval mode since there's no randomness.\n",
        "    _model_fn,\n",
        "    in_axes=(0, 0, 0, 0, 0),  # Map over the leading device axis of every argument.\n",
        "    devices=devices_to_use,\n",
        "    axis_name='batch',\n",
        ")\n",
        "\n",
        "render_fn = functools.partial(evaluation.render_image,\n",
        "                              model_fn=pmodel_fn,\n",
        "                              device_count=len(devices),\n",
        "                              chunk=eval_config.chunk)\n",
        "train_step = functools.partial(\n",
        "    training.train_step,\n",
        "    model,\n",
        "    elastic_reduce_method=train_config.elastic_reduce_method,\n",
        "    elastic_loss_type=train_config.elastic_loss_type,\n",
        "    use_elastic_loss=train_config.use_elastic_loss,\n",
        "    use_background_loss=train_config.use_background_loss,\n",
        "    use_warp_reg_loss=train_config.use_warp_reg_loss,\n",
        "    use_hyper_reg_loss=train_config.use_hyper_reg_loss,\n",
        ")\n",
        "ptrain_step = jax.pmap(\n",
        "    train_step,\n",
        "    axis_name='batch',\n",
        "    devices=devices,\n",
        "    # rng_key, state, batch, scalar_params.\n",
        "    in_axes=(0, 0, 0, None),\n",
        "    # The use_*_loss flags were bound above via functools.partial.\n",
        "    donate_argnums=(2,),  # Donate the 'batch' argument.\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vbc7cMr5aR_1",
        "cellView": "form"
      },
      "source": [
        "# @title Train!\n",
        "# @markdown This runs the training loop!\n",
        "\n",
        "import mediapy\n",
        "from hypernerf import utils\n",
        "from hypernerf import visualization as viz\n",
        "\n",
        "\n",
        "print_every_n_iterations = 100  # @param{type:'number'}\n",
        "visualize_results_every_n_iterations = 500  # @param{type:'number'}\n",
        "save_checkpoint_every_n_iterations = 1000  # @param{type:'number'}\n",
        "\n",
        "\n",
        "logging.info('Starting training')\n",
        "rng = rng + jax.process_index()  # Make random seed separate across hosts.\n",
        "keys = random.split(rng, len(devices))\n",
        "time_tracker = utils.TimeTracker()\n",
        "time_tracker.tic('data', 'total')\n",
        "\n",
        "for step, batch in zip(range(step, train_config.max_steps + 1), train_iter):\n",
        "  time_tracker.toc('data')\n",
        "  scalar_params = scalar_params.replace(\n",
        "      learning_rate=learning_rate_sched(step),\n",
        "      elastic_loss_weight=elastic_loss_weight_sched(step))\n",
        "  nerf_alpha = jax_utils.replicate(nerf_alpha_sched(step), devices)\n",
        "  warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)\n",
        "  hyper_alpha = jax_utils.replicate(hyper_alpha_sched(step), devices)\n",
        "  hyper_sheet_alpha = jax_utils.replicate(\n",
        "      hyper_sheet_alpha_sched(step), devices)\n",
        "  state = state.replace(nerf_alpha=nerf_alpha,\n",
        "                        warp_alpha=warp_alpha,\n",
        "                        hyper_alpha=hyper_alpha,\n",
        "                        hyper_sheet_alpha=hyper_sheet_alpha)\n",
        "\n",
        "  with time_tracker.record_time('train_step'):\n",
        "    state, stats, keys, _ = ptrain_step(keys, state, batch, scalar_params)\n",
        "    time_tracker.toc('total')\n",
        "\n",
        "  if step % print_every_n_iterations == 0:\n",
        "    logging.info(\n",
        "        'step=%d, warp_alpha=%.04f, hyper_alpha=%.04f, hyper_sheet_alpha=%.04f, %s',\n",
        "        step,\n",
        "        warp_alpha_sched(step),\n",
        "        hyper_alpha_sched(step),\n",
        "        hyper_sheet_alpha_sched(step),\n",
        "        time_tracker.summary_str('last'))\n",
        "    coarse_metrics_str = ', '.join(\n",
        "        [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()])\n",
        "    logging.info('\\tcoarse metrics: %s', coarse_metrics_str)\n",
        "    if 'fine' in stats:\n",
        "      fine_metrics_str = ', '.join(\n",
        "          [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()])\n",
        "      logging.info('\\tfine metrics: %s', fine_metrics_str)\n",
        "\n",
        "  if step % visualize_results_every_n_iterations == 0:\n",
        "    print(f'[step={step}] Training set visualization')\n",
        "    eval_batch = next(train_eval_iter)\n",
        "    render = render_fn(state, eval_batch, rng=rng)\n",
        "    rgb = render['rgb']\n",
        "    acc = render['acc']\n",
        "    depth_exp = render['depth']\n",
        "    depth_med = render['med_depth']\n",
        "    rgb_target = eval_batch['rgb']\n",
        "    depth_med_viz = viz.colorize(depth_med, cmin=datasource.near, cmax=datasource.far)\n",
        "    mediapy.show_images([rgb_target, rgb, depth_med_viz],\n",
        "                        titles=['GT RGB', 'Pred RGB', 'Pred Depth'])\n",
        "\n",
        "    print(f'[step={step}] Validation set visualization')\n",
        "    eval_batch = next(val_eval_iter)\n",
        "    render = render_fn(state, eval_batch, rng=rng)\n",
        "    rgb = render['rgb']\n",
        "    acc = render['acc']\n",
        "    depth_exp = render['depth']\n",
        "    depth_med = render['med_depth']\n",
        "    rgb_target = eval_batch['rgb']\n",
        "    depth_med_viz = viz.colorize(depth_med, cmin=datasource.near, cmax=datasource.far)\n",
        "    mediapy.show_images([rgb_target, rgb, depth_med_viz],\n",
        "                        titles=['GT RGB', 'Pred RGB', 'Pred Depth'])\n",
        "\n",
        "  if step % save_checkpoint_every_n_iterations == 0:\n",
        "    training.save_checkpoint(checkpoint_dir, state)\n",
        "\n",
        "  time_tracker.tic('data', 'total')\n"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}